# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handles agent transfer for LLM flow."""

from __future__ import annotations

import typing
from typing import AsyncGenerator

from typing_extensions import override

from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ...tools.function_tool import FunctionTool
from ...tools.tool_context import ToolContext
from ...tools.transfer_to_agent_tool import transfer_to_agent
from ._base_llm_processor import BaseLlmRequestProcessor

if typing.TYPE_CHECKING:
  from ...agents import BaseAgent
  from ...agents import LlmAgent


class _AgentTransferLlmRequestProcessor(BaseLlmRequestProcessor):
  """Request processor that prepares an LLM request for agent transfer.

  When the invocation's agent is an `LlmAgent` with at least one transfer
  target, the processor appends the transfer instructions to the request and
  hands the request to a `FunctionTool` wrapping `transfer_to_agent`. In every
  other case the request is left untouched. No events are ever emitted.
  """

  @override
  async def run_async(
      self, invocation_context: InvocationContext, llm_request: LlmRequest
  ) -> AsyncGenerator[Event, None]:
    """Adds transfer instructions and the transfer tool to `llm_request`.

    Args:
      invocation_context: Context of the current invocation; its `agent` is
        the agent the request is being built for.
      llm_request: The request to update in place.

    Yields:
      Nothing; the generator finishes without producing an event.
    """
    # Function-scope import, presumably to avoid a circular import between the
    # flow and agent packages -- TODO confirm.
    from ...agents.llm_agent import LlmAgent

    current_agent = invocation_context.agent
    if isinstance(current_agent, LlmAgent):
      candidates = _get_transfer_targets(current_agent)
      if candidates:
        instructions = _build_target_agents_instructions(
            current_agent, candidates
        )
        llm_request.append_instructions([instructions])

        transfer_tool = FunctionTool(func=transfer_to_agent)
        context_for_tool = ToolContext(invocation_context)
        await transfer_tool.process_llm_request(
            tool_context=context_for_tool, llm_request=llm_request
        )

    # The `yield` below is never reached; its presence alone is what makes
    # this coroutine function an async generator, as the declared return type
    # requires.
    return
    yield


# Module-level instance of the agent-transfer request processor, created once
# at import time. Presumably this is the object the LLM flow registers in its
# list of request processors -- verify against callers.
request_processor = _AgentTransferLlmRequestProcessor()


def _build_target_agents_info(target_agent: BaseAgent) -> str:
  """Formats the name and description of one transfer target.

  Args:
    target_agent: The agent to describe; only its `name` and `description`
      attributes are read.

  Returns:
    A block of text that begins and ends with a newline and holds one line for
    the agent name and one for the agent description.
  """
  # The empty first and last entries reproduce the leading and trailing
  # newline of the block.
  rows = (
      '',
      f'Agent name: {target_agent.name}',
      f'Agent description: {target_agent.description}',
      '',
  )
  return '\n'.join(rows)


# Newline bound to a name so it can be used inside the replacement field of the
# f-string in `_build_target_agents_instructions`; before Python 3.12 an
# f-string expression may not contain a backslash.
line_break = '\n'


def _build_target_agents_instructions(
    agent: LlmAgent, target_agents: list[BaseAgent]
) -> str:
  """Builds the instruction text that explains agent transfer to the model.

  Args:
    agent: The agent the request is being prepared for.
    target_agents: The agents the model may transfer to.

  Returns:
    Instruction text that lists every agent in `target_agents` and tells the
    model when to call the transfer function. A paragraph about falling back
    to the parent agent is appended only when the parent agent is itself one
    of `target_agents`.
  """
  # Joined outside the f-string so the template needs no helper name for the
  # newline separator.
  targets_info = '\n'.join(
      _build_target_agents_info(target_agent) for target_agent in target_agents
  )

  si = f"""
You have a list of other agents to transfer to:

{targets_info}

If you are the best to answer the question according to your description, you
can answer it.

If another agent is better for answering the question according to its
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
question to that agent. When transfering, do not generate any text other than
the function call.
"""

  parent_agent = agent.parent_agent
  # Mention the parent only when it is actually offered as a target. The
  # caller leaves it out of `target_agents` when transfer to the parent is
  # disallowed or the parent is not an LLM agent, and the prompt must not
  # point the model at an agent that is not on the list. Identity comparison
  # is used so no agent-to-agent equality check is triggered.
  if parent_agent and any(
      target is parent_agent for target in target_agents
  ):
    si += f"""
Your parent agent is {parent_agent.name}. If neither the other agents nor
you are best for answering the question according to the descriptions, transfer
to your parent agent. If you don't have parent agent, try answer by yourself.
"""
  return si


# Name of the transfer function as quoted in the instruction text. It is read
# from the function object rather than hard-coded, so the text follows a rename
# of `transfer_to_agent`. Defining it below the function that uses it is fine
# because the name is only looked up when that function is called.
_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__


def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]:
  """Collects the agents that `agent` may transfer to.

  Args:
    agent: The agent whose transfer targets are wanted.

  Returns:
    A new list holding, in order: the sub-agents of `agent`; then, only when
    the parent is an `LlmAgent`, the parent itself (unless
    `disallow_transfer_to_parent` is set) and the parent's other sub-agents
    (unless `disallow_transfer_to_peers` is set).
  """
  # Function-scope import, presumably to avoid a circular import -- TODO
  # confirm.
  from ...agents.llm_agent import LlmAgent

  targets = list(agent.sub_agents)

  parent = agent.parent_agent
  if parent and isinstance(parent, LlmAgent):
    if not agent.disallow_transfer_to_parent:
      targets.append(parent)

    if not agent.disallow_transfer_to_peers:
      # Peers are the parent's children; the agent itself is excluded by name.
      for sibling in parent.sub_agents:
        if sibling.name != agent.name:
          targets.append(sibling)

  return targets
